{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "visualize_speech_feature.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WKFhhlKKPEMt"
      },
      "source": [
        "Copyright 2020 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iHWimjnwPIJO"
      },
      "source": [
        "!git clone https://github.com/google-research/google-research.git"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s2LJTsGKPInr"
      },
      "source": [
        "import sys\n",
        "sys.path.append('./google-research')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ukEYOjN6PO18"
      },
      "source": [
        "# Example with speech feature visualization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oVX1y4ZH0JH3",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607123753410,
          "user_tz": 480,
          "elapsed": 102,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "import numpy as np\n",
        "from matplotlib import pylab as plt\n",
        "import scipy.io.wavfile as wav\n",
        "import scipy as scipy\n",
        "\n",
        "from kws_streaming.layers import modes\n",
        "from kws_streaming.layers import speech_features\n",
        "from kws_streaming.layers import test_utils\n",
        "from kws_streaming.layers.compat import tf\n",
        "from kws_streaming.models import model_params"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qHpFDG8kk3ph",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607123754262,
          "user_tz": 480,
          "elapsed": 138,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        },
        "outputId": "19f76830-0dbc-4e47-d4ce-5eb48c60ef6e"
      },
      "source": [
        "tf.compat.v1.enable_eager_execution()\n",
        "tf.executing_eagerly()"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eA8Ol_SF-Be-",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607123755824,
          "user_tz": 480,
          "elapsed": 86,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "def waveread_as_pcm16(filename):\n",
        "  \"\"\"Read in audio data from a wav file.  Return d, sr.\"\"\"\n",
        "  with tf.io.gfile.GFile(filename, 'rb') as file_handle:\n",
        "    sr, wave_data = wav.read(file_handle)\n",
        "  # Read in wav file.\n",
        "  return wave_data, sr\n",
        "\n",
        "def wavread_as_float(filename, target_sample_rate=16000):\n",
        "  \"\"\"Read in audio data from a wav file.  Return d, sr.\"\"\"\n",
        "  wave_data, sr = waveread_as_pcm16(filename)\n",
        "  desired_length = int(round(float(len(wave_data)) / sr * target_sample_rate))\n",
        "  wave_data = scipy.signal.resample(wave_data, desired_length)\n",
        "\n",
        "  # Normalize short ints to floats in range [-1..1).\n",
        "  data = np.array(wave_data, np.float32) / 32768.0\n",
        "  return data, target_sample_rate"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YElWstqgCRHA",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607123760925,
          "user_tz": 480,
          "elapsed": 129,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "def speech_feature_model(input_size, p):\n",
        "  speech_params = speech_features.SpeechFeatures.get_params(p)\n",
        "  mode = modes.Modes.TRAINING\n",
        "  inputs = tf.keras.layers.Input(shape=(input_size,), batch_size=p.batch_size, dtype=tf.float32)\n",
        "  outputs = speech_features.SpeechFeatures(speech_params, mode, p.batch_size)(inputs)\n",
        "  model = tf.keras.models.Model(inputs, outputs)\n",
        "  return model"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kKV_l9Ei0WAW",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126156195,
          "user_tz": 480,
          "elapsed": 190,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "params = model_params.Params()\n",
        "params.window_size_ms = 25.0\n",
        "params.window_stride_ms = 10.0\n",
        "params.preemph = 0.97\n",
        "params.use_spec_augment = 0\n",
        "params.use_spec_cutout = 0\n",
        "params.use_tf_fft = 0\n",
        "params.time_shift_ms = 0.0\n",
        "params.sp_time_shift_ms = 0.0\n",
        "params.resample = 0.0\n",
        "params.sp_resample = 0.0\n",
        "params.train = 0\n",
        "params.batch_size = 1\n",
        "params.mode = modes.Modes.NON_STREAM_INFERENCE\n",
        "params.data_stride = 1\n",
        "params.data_frame_padding = None\n",
        "params.fft_magnitude_squared = False"
      ],
      "execution_count": 189,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2mD5jGyI0iXS",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126158041,
          "user_tz": 480,
          "elapsed": 114,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "frame_size = int(\n",
        "    round(params.sample_rate * params.window_size_ms / 1000.0))\n",
        "frame_step = int(\n",
        "    round(params.sample_rate * params.window_stride_ms / 1000.0))"
      ],
      "execution_count": 190,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oChEQHHs96QR"
      },
      "source": [
        "# wave_filename = \"test_speech.wav\"\n",
        "# waveform_data, sr = wavread_as_float(wave_filename)\n",
        "\n",
        "samplerate = 16000\n",
        "data_size = 51200\n",
        "test_utils.set_seed(1)\n",
        "frequency = 1000\n",
        "waveform_data = np.cos(2.0*np.pi*frequency*np.arange(data_size)/samplerate) * 2 + np.random.rand(data_size) * 0.4"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yp_KR7RO-2OV",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126159162,
          "user_tz": 480,
          "elapsed": 142,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "signal = np.expand_dims(waveform_data, axis=0)\n",
        "data_size = signal.shape[1]"
      ],
      "execution_count": 192,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yNfwaNftGjNR"
      },
      "source": [
        "## Speech feature extractor: Data framing + Preemphasis + Windowing + DFT + Mel + log (no DCT: dct_num_features=0)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-f1gbju7CFyO",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607127233090,
          "user_tz": 480,
          "elapsed": 593,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "params.mel_num_bins = 80\n",
        "params.dct_num_features = 0  # no DCT\n",
        "params.feature_type = 'mfcc_tf'\n",
        "params.use_tf_fft = False\n",
        "params.mel_non_zero_only = False\n",
        "params.mel_upper_edge_hertz = 4000\n",
        "\n",
        "model1 = speech_feature_model(data_size, params)"
      ],
      "execution_count": 230,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SMjKBuYFKQjD"
      },
      "source": [
        "model1.layers[1].mag_rdft_mel.real_dft_tensor.shape"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hDX4kORFKUSi"
      },
      "source": [
        "mel_table1 = model1.layers[1].mag_rdft_mel.mel_weight_matrix.numpy()\n",
        "mel_table1.shape"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3LK9erqc8aSw"
      },
      "source": [
        "out1 = model1.predict(signal)\n",
        "plt.figure(figsize=(20, 5))\n",
        "plt.imshow(np.transpose(out1[0]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lWKHYgs5tiMR"
      },
      "source": [
        "plt.figure(figsize=(20, 5))\n",
        "for i in range(mel_table1.shape[1]):\n",
        "  plt.plot(mel_table1[:, i])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IZ5FmgxM7zx7",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126165762,
          "user_tz": 480,
          "elapsed": 183,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "# It makes sense to set it True only if params.mel_upper_edge_hertz is much smaller than 8000\n",
        "# then DFT will be computed only for frequencies which are non zero in mel spectrum - it saves computation\n",
        "params.mel_non_zero_only = True\n",
        "\n",
        "model2 = speech_feature_model(data_size, params)"
      ],
      "execution_count": 198,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kE6XS59tC3Kk"
      },
      "source": [
        "model2.layers[1].mag_rdft_mel.real_dft_tensor.shape"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AygsA3qCC3NX"
      },
      "source": [
        "mel_table2 = model2.layers[1].mag_rdft_mel.mel_weight_matrix.numpy()\n",
        "mel_table2.shape"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DQNZowLQC3QT",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126166220,
          "user_tz": 480,
          "elapsed": 164,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "out2 = model2.predict(signal)"
      ],
      "execution_count": 201,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "N__shtNeC_Rl"
      },
      "source": [
        "plt.figure(figsize=(20, 5))\n",
        "plt.imshow(np.transpose(out2[0]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "weKVO7v1DHm6"
      },
      "source": [
        "plt.figure(figsize=(20, 5))\n",
        "for i in range(mel_table2.shape[1]):\n",
        "  plt.plot(mel_table2[:, i])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2Ld1QYFcDTkl"
      },
      "source": [
        "np.allclose(out1, out2, atol=1e-06)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "utQx6LpKN_o7"
      },
      "source": [
        "## Compare mfcc_tf with mfcc_op"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-z8x4ADHEV3I",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126783553,
          "user_tz": 480,
          "elapsed": 645,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "params.mel_num_bins = 80\n",
        "params.dct_num_features = 20\n",
        "params.feature_type = 'mfcc_tf'\n",
        "params.use_tf_fft = False\n",
        "params.mel_non_zero_only = False\n",
        "params.fft_magnitude_squared = False\n",
        "params.mel_upper_edge_hertz = 4000\n",
        "params.preemph = 0.0  # mfcc_op des not have preemphasis\n",
        "\n",
        "model3 = speech_feature_model(data_size, params)"
      ],
      "execution_count": 223,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SEMCrlw0EV6K"
      },
      "source": [
        "out3 = model3.predict(signal)\n",
        "plt.figure(figsize=(20, 5))\n",
        "plt.imshow(np.transpose(out3[0]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IUi8ZN9TEcSr",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1607126838834,
          "user_tz": 480,
          "elapsed": 106,
          "user": {
            "displayName": "Oleg Rybakov",
            "photoUrl": "",
            "userId": "04792887722985073803"
          }
        }
      },
      "source": [
        "params.feature_type = 'mfcc_op'\n",
        "# it will call two functions:\n",
        "# 1 audio_spectrogram computes hann windowing,\n",
        "#   then FFT - magnitude has to be squared\n",
        "#   because next function - mfcc computes sqrt (it assumes magnitude is squared)\n",
        "# 2 mfcc - compute mel spectrum from the squared-magnitude FFT input by taking the\n",
        "# square root, then multiply it with mel table then apply log and compute DCT\n",
        "\n",
        "params.fft_magnitude_squared = True\n",
        "model4 = speech_feature_model(data_size, params)"
      ],
      "execution_count": 228,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bpBLThInElQE"
      },
      "source": [
        "out4 = model4.predict(signal)\n",
        "plt.figure(figsize=(20, 5))\n",
        "plt.imshow(np.transpose(out4[0]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "az3Vn9mZEr-_"
      },
      "source": [
        "# Features extracted with 'mfcc_op' are numerically different from 'mfcc_tf'\n",
        "np.allclose(out3, out4, atol=1e-6)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}